{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python38564bitvenvvenv9efd950bbdd2494f8072faf3b588558e",
   "display_name": "Python 3.8.5 64-bit ('.venv': venv)",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Search image in the Dataset\n",
    "\n",
    "With this notebook you can search for photos using natural language."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "## Load the dataset\n",
    "\n",
    "You will need the Unsplash Dataset and the precomputed feature vectors for this. If you don't want to process the whole dataset yourself, you can download the preprocessed feature vectors from [here](https://drive.google.com/drive/folders/1ozUUr8UQ2YWDSJyYIIun9V8Qg1TjqXEs?usp=sharing)."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Locate the dataset folder and the features stored inside it\n",
    "dataset_version = \"lite\"  # choose \"lite\" or \"full\"\n",
    "unsplash_dataset_path = Path(\"unsplash-dataset\") / dataset_version\n",
    "features_path = unsplash_dataset_path / \"features\"\n",
    "\n",
    "# Load the photos table (tab-separated, column names taken from the first row)\n",
    "photos = pd.read_csv(unsplash_dataset_path / \"photos.tsv000\", sep='\\t', header=0)\n",
    "\n",
    "# Load the feature array and the photo IDs saved next to it\n",
    "photo_features = np.load(features_path / \"features.npy\")\n",
    "photo_ids = list(pd.read_csv(features_path / \"photo_ids.csv\")['photo_id'])"
   ]
  },
  {
   "source": [
    "Load the CLIP network."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import clip\n",
    "import torch\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise on the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Load the open CLIP model (ViT-B/32) on the selected device\n",
    "model, preprocess = clip.load(\"ViT-B/32\", device=device)"
   ]
  },
  {
   "source": [
    "## Search\n",
    "\n",
    "Specify your search query and encode it to a feature vector using CLIP."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "search_query = \"Two dogs playing in the snow\"\n",
    "\n",
    "# No gradients are tracked while encoding the query\n",
    "with torch.no_grad():\n",
    "    # Tokenize the description and move the tokens to `device`\n",
    "    query_tokens = clip.tokenize(search_query).to(device)\n",
    "\n",
    "    # Encode the description using CLIP and scale it to unit length\n",
    "    text_encoded = model.encode_text(query_tokens)\n",
    "    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)"
   ]
  },
  {
   "source": [
    "Compare the text features to the image features and find the best match."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the encoded description to the CPU as a NumPy array\n",
    "text_features = text_encoded.cpu().numpy()\n",
    "\n",
    "# Score every photo with the dot product between the description and the photo vectors\n",
    "# (this is the cosine similarity, assuming the stored photo features are unit-normalized - TODO confirm)\n",
    "similarity_scores = text_features @ photo_features.T\n",
    "similarities = list(similarity_scores.squeeze(0))\n",
    "\n",
    "# Pair each score with its photo index and order the pairs from best to worst match\n",
    "best_photos = list(zip(similarities, range(photo_features.shape[0])))\n",
    "best_photos.sort(key=lambda pair: pair[0], reverse=True)"
   ]
  },
  {
   "source": [
    "## Display the results\n",
    "\n",
    "Show the top 3 results."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": "<img src=\"https://images.unsplash.com/photo-1577366761509-937637f02454?w=640\"/>",
      "text/plain": "<IPython.core.display.Image object>"
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "Photo by <a href=\"https://unsplash.com/@lukavovk?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Luka Vovk</a> on <a href=\"https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Unsplash</a>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/html": "<img src=\"https://images.unsplash.com/photo-1577366761509-937637f02454?w=640\"/>",
      "text/plain": "<IPython.core.display.Image object>"
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "Photo by <a href=\"https://unsplash.com/@lukavovk?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Luka Vovk</a> on <a href=\"https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Unsplash</a>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/html": "<img src=\"https://images.unsplash.com/photo-1583762713699-7a6d1b8b6679?w=640\"/>",
      "text/plain": "<IPython.core.display.Image object>"
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.core.display.HTML object>",
      "text/html": "Photo by <a href=\"https://unsplash.com/@tadekl?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Tadeusz Lakota</a> on <a href=\"https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Unsplash</a>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from html import escape\n",
    "\n",
    "from IPython.display import HTML, Image, display\n",
    "\n",
    "# Number of best matches to show\n",
    "num_results = 3\n",
    "\n",
    "# Iterate over the top results; slicing avoids an IndexError when fewer photos are available\n",
    "for _score, idx in best_photos[:num_results]:\n",
    "    # Retrieve the photo ID\n",
    "    photo_id = photo_ids[idx]\n",
    "\n",
    "    # Get all metadata for this photo\n",
    "    photo_data = photos[photos[\"photo_id\"] == photo_id].iloc[0]\n",
    "\n",
    "    # Display the photo\n",
    "    display(Image(url=photo_data[\"photo_image_url\"] + \"?w=640\"))\n",
    "\n",
    "    # Build the photographer's name, skipping missing parts: pandas reads empty\n",
    "    # fields as NaN, which would otherwise show up as \"nan\" in the attribution\n",
    "    name_parts = [photo_data[\"photographer_first_name\"], photo_data[\"photographer_last_name\"]]\n",
    "    photographer_name = escape(\" \".join(str(part) for part in name_parts if pd.notna(part)))\n",
    "\n",
    "    # Escape the values taken from the dataset before placing them into HTML\n",
    "    photographer_username = escape(str(photo_data[\"photographer_username\"]))\n",
    "\n",
    "    # Display the attribution text\n",
    "    display(HTML(f'Photo by <a href=\"https://unsplash.com/@{photographer_username}?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">{photographer_name}</a> on <a href=\"https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral\">Unsplash</a>'))\n",
    "    print()"
   ]
  }
 ]
}